{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test [U-Net](https://arxiv.org/abs/1505.04597)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test a pretrained U-Net model on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# this is for path management in jupyter notebook\n",
    "# note necessary if you're running in terminal or other IDEs\n",
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Imports and hyper-parameter settings for the network definition.\n",
    "import tensorflow as tf\n",
    "from nn import unet\n",
    "\n",
    "# model / input geometry\n",
    "class_num = 2                 # number of classes in the ground truth\n",
    "patch_size = (572, 572)       # size of each input patch\n",
    "\n",
    "# optimisation schedule\n",
    "lr = 1e-4                     # initial learning rate\n",
    "ds = 60                       # number of epochs before the learning rate decays\n",
    "dr = 0.1                      # decay factor: lr becomes lr*dr\n",
    "epochs = 1                    # number of epochs to train\n",
    "bs = 5                        # batch size\n",
    "\n",
    "# bookkeeping\n",
    "suffix = 'test'               # user-defined name for the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define network\n",
    "# The name `unet` is rebound below from the imported module to the model instance.\n",
    "# Importing the module under an alias here keeps this cell safe to re-run on a live\n",
    "# kernel (otherwise a second run would look up `UNet` on the instance and fail).\n",
    "from nn import unet as unet_module\n",
    "tf.reset_default_graph()      # clear the default TF graph before building the network\n",
    "unet = unet_module.UNet(class_num, patch_size, suffix=suffix, learn_rate=lr, decay_step=ds, decay_rate=dr,\n",
    "                        epochs=epochs, batch_size=bs)\n",
    "overlap = unet.get_overlap()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Imports and settings for data preparation.\n",
    "import numpy as np\n",
    "from collection import collectionMaker, collectionEditor\n",
    "\n",
    "ds_name = 'Inria'             # collection name, passed as clc_name when reading the collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gt_d255 might already exist, skip replacement\n",
      "=========================================Inria==========================================\n",
      "raw_data_path: /media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles\n",
      "field_name: ['austin', 'chicago', 'kitsap', 'tyrol-w', 'vienna']\n",
      "field_id: [0, 1, ..., 36]\n",
      "clc_name: Inria\n",
      "tile_dim: (5000, 5000)\n",
      "chan_mean: [ 103.2341876   108.95194825  100.14192002]\n",
      "Source file: *RGB*.tif\n",
      "GT file: *gt_d255*.png\n",
      "========================================================================================\n"
     ]
    }
   ],
   "source": [
    "# Read the collection described by the raw tiles directory, field names and field ids 0-36.\n",
    "cm = collectionMaker.read_collection(\n",
    "    raw_data_path=r'/media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles',\n",
    "    field_name='austin,chicago,kitsap,tyrol-w,vienna',\n",
    "    field_id=','.join(map(str, range(37))),\n",
    "    rgb_ext='RGB',\n",
    "    gt_ext='GT',\n",
    "    file_ext='tif',\n",
    "    force_run=False,\n",
    "    clc_name=ds_name)\n",
    "\n",
    "# Apply SingleChanMult with factor 1/255 to the 'GT' channel and register the result as 'gt_d255'.\n",
    "# NOTE(review): presumably this maps 0/255 labels to 0/1 - confirm in collectionEditor.\n",
    "gt_scaler = collectionEditor.SingleChanMult(cm.clc_dir, 1/255, ['GT', 'gt_d255'])\n",
    "gt_d255 = gt_scaler.run(force_run=False, file_ext='png', d_type=np.uint8)\n",
    "cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])\n",
    "cm.print_meta_data()\n",
    "\n",
    "# Field ids 6-36 form the training list; field ids 0-5 form the validation list.\n",
    "file_list_train = cm.load_files(field_id=','.join(map(str, range(6, 37))), field_ext='RGB,gt_d255')\n",
    "file_list_valid = cm.load_files(field_id=','.join(map(str, range(6))), field_ext='RGB,gt_d255')\n",
    "\n",
    "# Keep only the first three entries of the stored channel means.\n",
    "chan_mean = cm.meta_data['chan_mean'][:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Imports and settings for evaluation.\n",
    "from nn import nn_utils\n",
    "\n",
    "tile_size = (5000, 5000)      # tile dimensions (same as tile_dim in the collection metadata)\n",
    "gpu = 0                       # passed to unet.evaluate; presumably the GPU device index - confirm\n",
    "sfn = 32                      # passed to unet.evaluate as sfn; meaning not visible here - confirm in nn.unet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Summary of results:\n",
      "===================austin=58.06===================\n",
      "austin1=60.47\n",
      "austin2=63.45\n",
      "austin3=55.76\n",
      "austin4=52.61\n",
      "austin5=54.75\n",
      "==================chicago=60.59===================\n",
      "chicago1=60.98\n",
      "chicago2=64.59\n",
      "chicago3=59.27\n",
      "chicago4=57.73\n",
      "chicago5=57.33\n",
      "===================kitsap=46.27===================\n",
      "kitsap1=54.02\n",
      "kitsap2=54.50\n",
      "kitsap3=57.55\n",
      "kitsap4=24.15\n",
      "kitsap5=27.12\n",
      "==================tyrol-w=52.88===================\n",
      "tyrol-w1=52.91\n",
      "tyrol-w2=51.78\n",
      "tyrol-w3=54.09\n",
      "tyrol-w4=50.72\n",
      "tyrol-w5=54.09\n",
      "===================vienna=68.81===================\n",
      "vienna1=68.54\n",
      "vienna2=73.19\n",
      "vienna3=68.54\n",
      "vienna4=69.62\n",
      "vienna5=62.59\n",
      "==================Overall=61.86===================\n"
     ]
    }
   ],
   "source": [
    "# Verbose information will not be displayed.\n",
    "nn_utils.tf_warn_level(3)\n",
    "\n",
    "# Directory of the pretrained model to evaluate.\n",
    "# NOTE(review): the directory name encodes EP6 while `epochs` is set to 1 above - confirm this is intended.\n",
    "model_dir = r'/hdd6/Models/test/unet_test_PS(572, 572)_BS5_EP6_LR0.0001_DS60_DR0.1'\n",
    "\n",
    "# Evaluate on the validation file list; returns per-tile, per-field and overall results.\n",
    "tile_dict, field_dict, overall = unet.evaluate(\n",
    "    file_list_valid, patch_size, tile_size, bs, chan_mean, model_dir, gpu,\n",
    "    save_result_parent_dir='ersa', sfn=sfn, force_run=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
